package com.candy.common.filter;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import com.candy.common.constant.CommonConstant;
import jakarta.servlet.*;
import jakarta.servlet.http.HttpServletRequest;
import jodd.net.HttpMethod;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.AntPathMatcher;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
 * XSS filter.
 * <p>
 * Wraps HTTP requests whose method is neither GET nor OPTIONS in an
 * {@link XssHttpServletRequestWrapper}, unless the request's servlet path matches one of the
 * configured Ant-style exclude patterns.
 *
 * @author rong xi
 * @version 1.0
 * @date 2023/11/07 14:16
 */
@Slf4j
public class XssFilter implements Filter {

    /**
     * Separator between individual exclude patterns in the init-parameter value.
     */
    private static final String EXCLUDES_SEPARATOR = ",";

    /**
     * Ant-style URL patterns exempt from XSS filtering. Never {@code null}; empty until
     * {@link #init(FilterConfig)} supplies patterns.
     */
    private List<String> xssExcludes = Collections.emptyList();

    /**
     * Matcher used to test the request's servlet path against {@link #xssExcludes}.
     */
    private final AntPathMatcher pathMatcher = new AntPathMatcher();

    /**
     * Initializes the filter by reading the exclude patterns.
     * <p>
     * The value of the {@code CommonConstant.XSS_EXCLUDES_PARAM} init parameter is split on
     * commas. Each entry is trimmed and blank entries are discarded. A missing or blank parameter
     * yields an empty exclude list.
     *
     * @param filterConfig filter configuration supplying the init parameters
     */
    @Override
    public void init(FilterConfig filterConfig) {
        xssExcludes = Optional.ofNullable(filterConfig.getInitParameter(CommonConstant.XSS_EXCLUDES_PARAM))
                .filter(StrUtil::isNotBlank)
                .map(XssFilter::parseExcludes)
                .orElse(Collections.emptyList());
    }

    /**
     * Splits a comma-separated pattern list into trimmed, non-blank patterns.
     *
     * @param raw non-blank init-parameter value
     * @return unmodifiable list of patterns, possibly empty
     */
    private static List<String> parseExcludes(String raw) {
        List<String> patterns = new ArrayList<>();
        for (String part : raw.split(EXCLUDES_SEPARATOR)) {
            String pattern = part.trim();
            if (!pattern.isEmpty()) {
                patterns.add(pattern);
            }
        }
        return Collections.unmodifiableList(patterns);
    }

    /**
     * Filters the request.
     * <p>
     * Requests selected by {@link #needXssCheck(HttpServletRequest)} are passed down the chain
     * wrapped in an {@link XssHttpServletRequestWrapper}. All others, including non-HTTP requests,
     * are passed through unchanged.
     *
     * @param request  the request
     * @param response the response
     * @param chain    the filter chain
     * @throws IOException      if the downstream chain throws it
     * @throws ServletException if the downstream chain throws it
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        // Only HTTP requests carry the method/path needed for the exclusion rules; pass anything
        // else through instead of failing with a ClassCastException.
        if (!(request instanceof HttpServletRequest)) {
            chain.doFilter(request, response);
            return;
        }
        HttpServletRequest req = (HttpServletRequest) request;
        if (needXssCheck(req)) {
            // DEBUG rather than INFO: this runs for every filtered request.
            log.debug("XSS过滤路径:{}:{}", req.getMethod(), req.getRequestURI());
            chain.doFilter(new XssHttpServletRequestWrapper(req), response);
        } else {
            chain.doFilter(request, response);
        }
    }

    /**
     * Decides whether this request should go through XSS filtering.
     *
     * @param request the HTTP request
     * @return {@code true} if the request should be wrapped for XSS filtering
     */
    private boolean needXssCheck(HttpServletRequest request) {
        String method = request.getMethod();
        // Requests without a method, GET requests and OPTIONS requests are never filtered.
        if (method == null
                || HttpMethod.GET.name().equalsIgnoreCase(method)
                || HttpMethod.OPTIONS.name().equalsIgnoreCase(method)) {
            return false;
        }
        // No exclusions configured: every remaining request is filtered.
        if (CollectionUtil.isEmpty(xssExcludes)) {
            return true;
        }
        // Filter unless the servlet path matches one of the exclude patterns.
        String url = request.getServletPath();
        for (String pattern : xssExcludes) {
            if (pathMatcher.match(pattern, url)) {
                return false;
            }
        }
        return true;
    }

}